import logging
import os

import torch
from contextlib import nullcontext
from torch.distributed import _coalescing_manager

from megatron.core.distributed.param_and_grad_buffer import _ParamAndGradBucketGroup
import megatron.core.parallel_state as parallel_state
from megatron.core.utils import is_torch_min_version

from .epx_sync_tensor import epx_sync_tensor_across_replicas

logger = logging.getLogger(__name__)

if is_torch_min_version("1.13.0"):
    dist_all_gather_func = torch.distributed.all_gather_into_tensor
    dist_reduce_scatter_func = torch.distributed.reduce_scatter_tensor
else:
    dist_all_gather_func = torch.distributed._all_gather_base
    dist_reduce_scatter_func = torch.distributed._reduce_scatter_base

def shard_buffer(buffer: torch.Tensor, data_parallel_world_size: int):
    """
    Shard buffer into data_parallel_world_size chunks of equal size.
    """
    assert buffer.numel() % data_parallel_world_size == 0
    shard_size = buffer.numel() // data_parallel_world_size
    sharded_buffer = [
        buffer[(r * shard_size) : ((r + 1) * shard_size)] for r in range(data_parallel_world_size)
    ]
    return sharded_buffer

def start_grad_sync(self):
    """
    Initiates grad sync (all-reduce or reduce-scatter) communication operations
    for all buckets in the bucket group.

    When ddp_config.overlap_grad_reduce is set to True, dispatches an asynchronous
    communication call. When ddp_config.overlap_grad_reduce is set to False, makes
    synchronous call.
    """
    assert (
        self.grad_reduce_handle is None
    ), 'Should not have multiple communication calls outstanding at once'
    #print('before')

    if self.ddp_config.check_for_nan_in_grad or self.ddp_config.check_for_large_grads:
        self.check_grads(
            check_for_nan_or_inf=self.ddp_config.check_for_nan_in_grad,
            check_for_large=self.ddp_config.check_for_large_grads,
        )

    # gradient_scaling_factor already takes into account whether we are computing
    # an average or sum in the data-parallel collective.
    for bucket in self.buckets:
        if bucket.gradient_scaling_factor != 1.0:
            bucket.grad_data *= bucket.gradient_scaling_factor

    # Decide reduce_op.
    reduce_op = torch.distributed.ReduceOp.SUM
    if self.ddp_config.average_in_collective:
        reduce_op = torch.distributed.ReduceOp.AVG

    # We use the following stream synchronization for the gradient reduction
    # within and across DistOpt instances.

    # Compute Stream: -------------Gradient compute-------------------
    # Comm. Stream:   ------(wait for NCCL)-----(wait for NCCL)-------
    # NCCL Stream:          -------RS------     -------AR------

    # Use async communications only when overlap_grad_reduce is True.
    async_op = (
        self.ddp_config.overlap_grad_reduce
        and self.ddp_config.num_distributed_optimizer_instances == 1
    )
    if (
        self.ddp_config.num_distributed_optimizer_instances > 1
        and self.ddp_config.overlap_grad_reduce
    ):
        # Assign a communication stream if we have multiple DistOpt instances and we
        # need to overlap communication.
        stream_context = torch.cuda.stream(self.communication_stream)

        # The RS/AR communication stream needs to wait for the default stream
        # to complete its gradient computation before launching the next
        # gradient reduction collective.
        self.communication_stream.wait_stream(torch.cuda.default_stream())
    else:
        stream_context = nullcontext()

    if self.ddp_config.use_distributed_optimizer:
        communication_group = self.intra_distributed_optimizer_instance_group
    else:
        communication_group = self.data_parallel_group

    # Coalesce communication kernels across buckets in the bucket group.
    with stream_context, _coalescing_manager(communication_group, async_ops=async_op) as cm:
        for bucket in self.buckets:
            if self.ddp_config.use_distributed_optimizer:
                local_data_view = shard_buffer(
                    bucket.grad_data, self.intra_distributed_optimizer_instance_size
                )[self.intra_distributed_optimizer_instance_rank]

                dist_reduce_scatter_func(
                    local_data_view,
                    bucket.grad_data,
                    op=reduce_op,
                    group=communication_group,
                    async_op=async_op,
                )

                if int(os.getenv("USE_EPX", 0)) and not async_op:
                    epx_sync_tensor_across_replicas(local_data_view, opts=torch.distributed.ReduceOp.AVG)
            else:
                torch.distributed.all_reduce(
                    bucket.grad_data, op=reduce_op, group=communication_group, async_op=async_op
                )
                if int(os.getenv("USE_EPX", 0)) and not async_op:
                    epx_sync_tensor_across_replicas(bucket.grad_data, opts=torch.distributed.ReduceOp.AVG)

    # print('before before allreduce')
    # With multiple DistOpt instances, we need to all-reduce across instances.
    if (
        self.ddp_config.use_distributed_optimizer
        and self.ddp_config.num_distributed_optimizer_instances > 1
    ):
        # Create a new coalescing manager for the inter-instance all-reduce.
        with stream_context, _coalescing_manager(
            self.inter_distributed_optimizer_instance_group, async_ops=async_op
        ) as cm:
            for bucket in self.buckets:
                local_data_view = shard_buffer(
                    bucket.grad_data, self.intra_distributed_optimizer_instance_size
                )[self.intra_distributed_optimizer_instance_rank]
                # print('before all reduce')
                torch.distributed.all_reduce(
                    local_data_view,
                    op=reduce_op,
                    group=self.inter_distributed_optimizer_instance_group,
                    async_op=async_op,
                )
                # print('after all reduce')
    # print('after after allreduce')

    if async_op:
        self.grad_reduce_handle = cm
    else:
        # When using `_coalescing_manager`, even if a synchronous op (async_op=False) is used,
        # `cm` is not None, which is different from when `_coalescing_manager` is not used in
        # which case the torch.distributed._reduce_scatter_base() will return None. In order to
        # maintain consistency with prior code, we need to manually set communication handle to
        # None.
        self.grad_reduce_handle = None

def finish_grad_sync(self):
    """
    Finishes grad sync (all-reduce or reduce-scatter) communication operations
    for all buckets in the bucket group.

    When ddp_config.overlap_grad_reduce is set to True, waits for asynchronous
    communication call to complete. When ddp_config.overlap_grad_reduce is set to False,
    makes synchronous call.
    """
    self.param_gather_dispatched = False

    # If overlap_grad_reduce is False, start (and finish) synchronous communication call here.
    # print(f'before self.ddp_config.overlap_grad_reduce is {self.ddp_config.overlap_grad_reduce}')
    if not self.ddp_config.overlap_grad_reduce:
        self.start_grad_sync()
        return
    # print(f'after self.ddp_config.overlap_grad_reduce is {self.ddp_config.overlap_grad_reduce}')
    # When using multiple DistOpt instances, we don't need to sync here as we launch
    # communications on a separate communication stream.
    if self.ddp_config.num_distributed_optimizer_instances > 1:
        torch.cuda.default_stream().wait_stream(self.communication_stream)
        return
    assert self.grad_reduce_handle is not None, (
        f'Communication call has not been issued for this bucket '
        f'({len(self.params_with_grad)}/{len(self.params)} params have grad available)'
    )
    self.grad_reduce_handle.wait()
    self.grad_reduce_handle = None

    # TODO: Using `_coalescing_manager` to optimize code structure.
    if int(os.getenv("USE_EPX", 0)):
        for bucket in self.buckets:
            if self.ddp_config.use_distributed_optimizer:
                local_data_view = shard_buffer(
                    bucket.grad_data, self.intra_distributed_optimizer_instance_size
                )[self.intra_distributed_optimizer_instance_rank]
                if int(os.getenv("USE_EPX", 0)):
                    epx_sync_tensor_across_replicas(local_data_view, opts=torch.distributed.ReduceOp.AVG)
            else:
                if int(os.getenv("USE_EPX", 0)):
                    epx_sync_tensor_across_replicas(bucket.grad_data, opts=torch.distributed.ReduceOp.AVG)

_ParamAndGradBucketGroup.start_grad_sync = start_grad_sync
_ParamAndGradBucketGroup.finish_grad_sync = finish_grad_sync
